// Copyright (c) 2018-2024 The Dash Core developers
// Distributed under the MIT/X11 software license, see the accompanying
// file COPYING or http://www.opensource.org/licenses/mit-license.php.

#include <llmq/signing.h>

#include <llmq/commitment.h>
#include <llmq/params.h>
#include <llmq/quorums.h>
#include <llmq/signhash.h>

#include <chainparams.h>
#include <dbwrapper.h>
#include <streams.h>
#include <util/system.h>

#include <algorithm>
#include <unordered_map>
#include <unordered_set>

namespace llmq
{
// Creates the database wrapper backing the recovered-signature store.
// The wrapper is opened at <db_params.path>/llmq/recsigdb, forwards the `memory` and `wipe`
// settings from db_params, and requests a cache_size of 8 << 20 (8 MiB if the unit is bytes).
CRecoveredSigsDb::CRecoveredSigsDb(const util::DbWrapperParams& db_params) :
    db{util::MakeDbWrapper({db_params.path / "llmq" / "recsigdb", db_params.memory, db_params.wipe, /*cache_size=*/8 << 20})}
{
}

// Out-of-line defaulted destructor; members clean up after themselves.
CRecoveredSigsDb::~CRecoveredSigsDb() = default;

// Returns true when an "rs_r" entry exists for exactly this (llmqType, id, msgHash) triple.
// This is an uncached database lookup.
bool CRecoveredSigsDb::HasRecoveredSig(Consensus::LLMQType llmqType, const uint256& id, const uint256& msgHash) const
{
    return db->Exists(std::make_tuple(std::string("rs_r"), llmqType, id, msgHash));
}

// Returns true when a recovered sig is stored for (llmqType, id), regardless of its msgHash.
// The answer is served from hasSigForIdCache when possible; otherwise the database is queried
// (outside of cs_cache) and the result is remembered.
bool CRecoveredSigsDb::HasRecoveredSigForId(Consensus::LLMQType llmqType, const uint256& id) const
{
    const auto cacheKey = std::make_pair(llmqType, id);
    {
        LOCK(cs_cache);
        bool cached{false};
        if (hasSigForIdCache.get(cacheKey, cached)) {
            return cached;
        }
    }

    // Cache miss: look for the "rs_r" key that stores the recovered sig itself.
    const bool found = db->Exists(std::make_tuple(std::string("rs_r"), llmqType, id));

    LOCK(cs_cache);
    hasSigForIdCache.insert(cacheKey, found);
    return found;
}

// Returns true when an "rs_s" entry exists for the given sign hash (signing session).
// Served from hasSigForSessionCache when possible; a miss queries the database outside of
// cs_cache and stores the result in the cache.
bool CRecoveredSigsDb::HasRecoveredSigForSession(const uint256& signHash) const
{
    {
        LOCK(cs_cache);
        bool cached{false};
        if (hasSigForSessionCache.get(signHash, cached)) {
            return cached;
        }
    }

    const bool found = db->Exists(std::make_tuple(std::string("rs_s"), signHash));

    LOCK(cs_cache);
    hasSigForSessionCache.insert(signHash, found);
    return found;
}

// Returns true when an "rs_h" entry exists for the given recovered-sig object hash.
// Served from hasSigForHashCache when possible; a miss queries the database outside of
// cs_cache and stores the result in the cache.
bool CRecoveredSigsDb::HasRecoveredSigForHash(const uint256& hash) const
{
    {
        LOCK(cs_cache);
        bool cached{false};
        if (hasSigForHashCache.get(hash, cached)) {
            return cached;
        }
    }

    const bool found = db->Exists(std::make_tuple(std::string("rs_h"), hash));

    LOCK(cs_cache);
    hasSigForHashCache.insert(hash, found);
    return found;
}

// Loads the recovered sig stored under "rs_r" for (llmqType, id) into `ret`.
// Returns false when the key is missing or when the stored bytes fail to deserialize.
bool CRecoveredSigsDb::ReadRecoveredSig(Consensus::LLMQType llmqType, const uint256& id, CRecoveredSig& ret) const
{
    CDataStream stream(SER_DISK, CLIENT_VERSION);
    const auto dbKey = std::make_tuple(std::string("rs_r"), llmqType, id);
    if (!db->ReadDataStream(dbKey, stream)) {
        return false;
    }

    try {
        ret.Unserialize(stream);
    } catch (std::exception&) {
        // A deserialization failure is reported as "not found"
        return false;
    }
    return true;
}

bool CRecoveredSigsDb::GetRecoveredSigByHash(const uint256& hash, CRecoveredSig& ret) const
{
    auto k1 = std::make_tuple(std::string("rs_h"), hash);
    std::pair<Consensus::LLMQType, uint256> k2;
    if (!db->Read(k1, k2)) {
        return false;
    }

    return ReadRecoveredSig(k2.first, k2.second, ret);
}

// Loads the recovered sig for (llmqType, id) into `ret`; thin wrapper around ReadRecoveredSig.
bool CRecoveredSigsDb::GetRecoveredSigById(Consensus::LLMQType llmqType, const uint256& id, CRecoveredSig& ret) const
{
    const bool loaded = ReadRecoveredSig(llmqType, id, ret);
    return loaded;
}

// Persists a recovered sig together with all of its lookup keys in a single batch and then
// marks it as present in the three in-memory caches.
//
// Keys written:
//   "rs_r" (type, id)            -> the recovered sig
//   "rs_r" (type, id, msgHash)   -> write time
//   "rs_h" (object hash)         -> (type, id)
//   "rs_s" (sign hash)           -> 1
//   "rs_t" (big-endian time, type, id) -> 1
void CRecoveredSigsDb::WriteRecoveredSig(const llmq::CRecoveredSig& recSig)
{
    const uint32_t writeTime = GetTime<std::chrono::seconds>().count();
    const auto type = recSig.getLlmqType();
    const auto& sigId = recSig.getId();
    const auto& objHash = recSig.GetHash();
    auto sessionHash = recSig.buildSignHash();

    CDBBatch batch(*db);

    // we put these close to each other to leverage leveldb's key compaction
    // this way, the second key can be used for fast HasRecoveredSig checks while the first key stores the recSig
    batch.Write(std::make_tuple(std::string("rs_r"), type, sigId), recSig);
    // this key is also used to store the current time, so that we can easily get to the "rs_t" key when we have the id
    batch.Write(std::make_tuple(std::string("rs_r"), type, sigId, recSig.getMsgHash()), writeTime);

    // store by object hash
    batch.Write(std::make_tuple(std::string("rs_h"), objHash), std::make_pair(type, sigId));

    // store by signHash
    batch.Write(std::make_tuple(std::string("rs_s"), sessionHash.Get()), static_cast<uint8_t>(1));

    // store by current time. Allows fast cleanup of old recSigs
    batch.Write(std::make_tuple(std::string("rs_t"), static_cast<uint32_t>(htobe32_internal(writeTime)), type, sigId),
                static_cast<uint8_t>(1));

    db->WriteBatch(batch);

    // Only after the batch was handed to the database do the caches learn about the new entry
    LOCK(cs_cache);
    hasSigForIdCache.insert(std::make_pair(type, sigId), true);
    hasSigForSessionCache.insert(sessionHash.Get(), true);
    hasSigForHashCache.insert(objHash, true);
}

// Queues the erasure of a recovered sig's keys into `batch` (the caller writes the batch) and
// drops the matching cache entries.
//
// The two "rs_r" keys are always erased. The "rs_h" and "rs_s" keys (and their cache entries) are
// only erased when deleteHashKey is set. The "rs_t" key is only erased when deleteTimeKey is set
// and the stored write time can be read back.
// Nothing happens if no recovered sig is stored for (llmqType, id).
void CRecoveredSigsDb::RemoveRecoveredSig(CDBBatch& batch, Consensus::LLMQType llmqType, const uint256& id, bool deleteHashKey, bool deleteTimeKey)
{
    CRecoveredSig stored;
    if (!ReadRecoveredSig(llmqType, id, stored)) {
        return;
    }

    const auto storedType = stored.getLlmqType();
    const auto& storedId = stored.getId();
    auto sessionHash = stored.buildSignHash();

    const auto sigKey = std::make_tuple(std::string("rs_r"), storedType, storedId);
    const auto sigAndMsgKey = std::make_tuple(std::string("rs_r"), storedType, storedId, stored.getMsgHash());
    batch.Erase(sigKey);
    batch.Erase(sigAndMsgKey);
    if (deleteHashKey) {
        batch.Erase(std::make_tuple(std::string("rs_h"), stored.GetHash()));
        batch.Erase(std::make_tuple(std::string("rs_s"), sessionHash.Get()));
    }

    if (deleteTimeKey) {
        // The write time lives in the value of the (type, id, msgHash) key; it is needed to rebuild the "rs_t" key
        CDataStream writeTimeDs(SER_DISK, CLIENT_VERSION);
        // TODO remove the size() == sizeof(uint32_t) in a future version (when we stop supporting upgrades from < 0.14.1)
        const bool haveWriteTime = db->ReadDataStream(sigAndMsgKey, writeTimeDs) && writeTimeDs.size() == sizeof(uint32_t);
        if (haveWriteTime) {
            uint32_t writeTime;
            writeTimeDs >> writeTime;
            batch.Erase(std::make_tuple(std::string("rs_t"), (uint32_t) htobe32_internal(writeTime), storedType, storedId));
        }
    }

    LOCK(cs_cache);
    hasSigForIdCache.erase(std::make_pair(storedType, storedId));
    if (deleteHashKey) {
        hasSigForSessionCache.erase(sessionHash.Get());
        hasSigForHashCache.erase(stored.GetHash());
    }
}

// Drops the recovered sig itself plus the keys needed to get from id -> recSig.
// The byHash and signHash keys (and the time key) stay in place, so HasRecoveredSigForHash /
// late-share filtering still returns true
void CRecoveredSigsDb::TruncateRecoveredSig(Consensus::LLMQType llmqType, const uint256& id)
{
    CDBBatch truncateBatch(*db);
    RemoveRecoveredSig(truncateBatch, llmqType, id, /*deleteHashKey=*/false, /*deleteTimeKey=*/false);
    db->WriteBatch(truncateBatch);
}

void CRecoveredSigsDb::CleanupOldRecoveredSigs(int64_t maxAge)
{
    std::unique_ptr<CDBIterator> pcursor(db->NewIterator());

    auto start = std::make_tuple(std::string("rs_t"), (uint32_t)0, (Consensus::LLMQType)0, uint256());
    uint32_t endTime = (uint32_t)(GetTime<std::chrono::seconds>().count() - maxAge);
    pcursor->Seek(start);

    std::vector<std::pair<Consensus::LLMQType, uint256>> toDelete;
    std::vector<decltype(start)> toDelete2;

    while (pcursor->Valid()) {
        decltype(start) k;

        if (!pcursor->GetKey(k) || std::get<0>(k) != "rs_t") {
            break;
        }
        if (be32toh_internal(std::get<1>(k)) >= endTime) {
            break;
        }

        toDelete.emplace_back(std::get<2>(k), std::get<3>(k));
        toDelete2.emplace_back(k);

        pcursor->Next();
    }
    pcursor.reset();

    if (toDelete.empty()) {
        return;
    }

    CDBBatch batch(*db);
    for (const auto& e : toDelete) {
        RemoveRecoveredSig(batch, e.first, e.second, true, false);

        if (batch.SizeEstimate() >= (1 << 24)) {
            db->WriteBatch(batch);
            batch.Clear();
        }
    }

    for (const auto& e : toDelete2) {
        batch.Erase(e);
    }

    db->WriteBatch(batch);

    LogPrint(BCLog::LLMQ, "CRecoveredSigsDb::%d -- deleted %d entries\n", __func__, toDelete.size());
}

// Returns true when an "rs_v" vote entry exists for (llmqType, id).
bool CRecoveredSigsDb::HasVotedOnId(Consensus::LLMQType llmqType, const uint256& id) const
{
    return db->Exists(std::make_tuple(std::string("rs_v"), llmqType, id));
}

// Reads the msgHash stored in the "rs_v" vote entry for (llmqType, id) into msgHashRet.
// Returns false when the read fails.
bool CRecoveredSigsDb::GetVoteForId(Consensus::LLMQType llmqType, const uint256& id, uint256& msgHashRet) const
{
    const auto voteKey = std::make_tuple(std::string("rs_v"), llmqType, id);
    const bool found = db->Read(voteKey, msgHashRet);
    return found;
}

// Records that we voted for `msgHash` on (llmqType, id).
// Two keys are written in one batch: "rs_v" (type, id) -> msgHash, and the time index
// "rs_vt" (big-endian current time, type, id) -> 1 that CleanupOldVotes iterates.
void CRecoveredSigsDb::WriteVoteForId(Consensus::LLMQType llmqType, const uint256& id, const uint256& msgHash)
{
    const uint32_t voteTime = GetTime<std::chrono::seconds>().count();

    CDBBatch batch(*db);
    batch.Write(std::make_tuple(std::string("rs_v"), llmqType, id), msgHash);
    batch.Write(std::make_tuple(std::string("rs_vt"), (uint32_t)htobe32_internal(voteTime), llmqType, id),
                static_cast<uint8_t>(1));

    db->WriteBatch(batch);
}

// Deletes every vote whose "rs_vt" time key is older than `maxAge` seconds, together with the
// matching "rs_v" entry.
//
// maxAge: maximum age in seconds. When it is not smaller than the current time, no entry can be
//         that old and nothing is deleted.
void CRecoveredSigsDb::CleanupOldVotes(int64_t maxAge)
{
    const int64_t now = GetTime<std::chrono::seconds>().count();
    if (maxAge >= now) {
        // The cutoff would be <= 0. Converting a negative cutoff to uint32_t wraps around to a huge
        // timestamp, which would make *every* vote look expired.
        return;
    }
    const uint32_t endTime = static_cast<uint32_t>(now - maxAge);

    std::unique_ptr<CDBIterator> pcursor(db->NewIterator());

    auto start = std::make_tuple(std::string("rs_vt"), (uint32_t)0, (Consensus::LLMQType)0, uint256());
    pcursor->Seek(start);

    CDBBatch batch(*db);
    size_t cnt = 0;
    while (pcursor->Valid()) {
        decltype(start) k;

        if (!pcursor->GetKey(k) || std::get<0>(k) != "rs_vt") {
            break;
        }
        // Times are stored big-endian, so keys iterate oldest first and we can stop at the first fresh one
        if (be32toh_internal(std::get<1>(k)) >= endTime) {
            break;
        }

        Consensus::LLMQType llmqType = std::get<2>(k);
        const uint256& id = std::get<3>(k);

        // Erase the time index entry and the vote it points to
        batch.Erase(k);
        batch.Erase(std::make_tuple(std::string("rs_v"), llmqType, id));

        cnt++;

        pcursor->Next();
    }
    // Release the iterator before writing
    pcursor.reset();

    if (cnt == 0) {
        return;
    }

    db->WriteBatch(batch);

    LogPrint(BCLog::LLMQ, "CRecoveredSigsDb::%s -- deleted %d entries\n", __func__, cnt);
}

//////////////////

// Builds the signing manager: constructs the recovered-sig database from db_params and keeps a
// reference to the quorum manager used for quorum lookups.
CSigningManager::CSigningManager(const CQuorumManager& _qman, const util::DbWrapperParams& db_params) :
    db{db_params},
    qman{_qman}
{
}

// Out-of-line defaulted destructor; members clean up after themselves.
CSigningManager::~CSigningManager() = default;

bool CSigningManager::AlreadyHave(const CInv& inv) const
{
    if (inv.type != MSG_QUORUM_RECOVERED_SIG) {
        return false;
    }
    {
        LOCK(cs_pending);
        if (pendingReconstructedRecoveredSigs.count(inv.hash)) {
            return true;
        }
    }

    return db.HasRecoveredSigForHash(inv.hash);
}

// Loads the recovered sig with the given object hash into `ret` for serving a getdata request.
// Returns false when the sig is unknown or belongs to a quorum that is no longer active.
bool CSigningManager::GetRecoveredSigForGetData(const uint256& hash, CRecoveredSig& ret) const
{
    // we don't want to propagate sigs from inactive quorums
    return db.GetRecoveredSigByHash(hash, ret) &&
           IsQuorumActive(ret.getLlmqType(), qman, ret.getQuorumHash());
}

// Pre-filters a recovered sig received from peer `from` and queues it for later verification.
// The sig is dropped when its quorum is unknown or inactive, when it is already in the database,
// or when the same sig is already pending as a reconstructed sig.
void CSigningManager::VerifyAndProcessRecoveredSig(NodeId from, std::shared_ptr<CRecoveredSig> recoveredSig)
{
    const auto type = recoveredSig->getLlmqType();
    const auto quorum = qman.GetQuorum(type, recoveredSig->getQuorumHash());

    if (!quorum) {
        LogPrint(BCLog::LLMQ, "CSigningManager::%s -- quorum %s not found\n", __func__,
                 recoveredSig->getQuorumHash().ToString());
        return;
    }
    if (!IsQuorumActive(type, qman, quorum->qc->quorumHash)) {
        return;
    }

    const auto& recSigHash = recoveredSig->GetHash();

    // It's important to only skip seen *valid* sig shares here. See comment for CBatchedSigShare
    // We don't receive recovered sigs in batches, but we do batched verification per node on these
    if (db.HasRecoveredSigForHash(recSigHash)) {
        return;
    }

    LogPrint(BCLog::LLMQ, "CSigningManager::%s -- signHash=%s, id=%s, msgHash=%s, node=%d\n", __func__,
             recoveredSig->buildSignHash().ToString(), recoveredSig->getId().ToString(), recoveredSig->getMsgHash().ToString(), from);

    LOCK(cs_pending);
    const bool alreadyReconstructed = pendingReconstructedRecoveredSigs.count(recSigHash) != 0;
    if (alreadyReconstructed) {
        // no need to perform full verification
        LogPrint(BCLog::LLMQ, "CSigningManager::%s -- already pending reconstructed sig, signHash=%s, id=%s, msgHash=%s, node=%d\n", __func__,
                 recoveredSig->buildSignHash().ToString(), recoveredSig->getId().ToString(), recoveredSig->getMsgHash().ToString(), from);
        return;
    }

    // Queue per sending node; verification happens later
    pendingRecoveredSigs[from].emplace_back(std::move(recoveredSig));
}

// Moves pending recovered sigs out of the per-node queues into `retSigShares` and fills
// `ret_pubkeys` with the public key of every quorum referenced by the collected sigs.
//
// maxUniqueSessions: collection continues only while fewer than this many (node, signHash)
//                    pairs have been gathered.
// retSigShares:      output, collected sigs grouped by the node they came from. Sigs whose quorum
//                    cannot be found or is not active are removed again before returning.
// ret_pubkeys:       output, quorum public keys keyed by (llmqType, quorumHash).
//
// Returns false when nothing was pending or nothing was collected; otherwise returns whether
// more work remains (a non-empty per-node queue or pending reconstructed sigs).
bool CSigningManager::CollectPendingRecoveredSigsToVerify(
    size_t maxUniqueSessions, std::unordered_map<NodeId, std::list<std::shared_ptr<const CRecoveredSig>>>& retSigShares,
    std::unordered_map<std::pair<Consensus::LLMQType, uint256>, CBLSPublicKey, StaticSaltedHasher>& ret_pubkeys)
{
    bool more_work{false};

    {
        LOCK(cs_pending);
        if (pendingRecoveredSigs.empty()) {
            return false;
        }

        // TODO: refactor it to remove duplicated code with `CSigSharesManager::CollectPendingSigSharesToVerify`
        std::unordered_set<std::pair<NodeId, uint256>, StaticSaltedHasher> uniqueSignHashes;
        // NOTE(review): IterateNodesRandom is defined elsewhere; presumably it visits the nodes in random
        // order for as long as the first callback returns true, and stops revisiting a node once the
        // second callback returns false -- confirm against its definition.
        IterateNodesRandom(pendingRecoveredSigs, [&]() {
            return uniqueSignHashes.size() < maxUniqueSessions;
        }, [&](NodeId nodeId, std::list<std::shared_ptr<const CRecoveredSig>>& ns) {
            if (ns.empty()) {
                return false;
            }
            // Take the oldest queued sig of this node
            auto& recSig = *ns.begin();

            // Sigs already stored in the database are silently dropped from the queue
            bool alreadyHave = db.HasRecoveredSigForHash(recSig->GetHash());
            if (!alreadyHave) {
                uniqueSignHashes.emplace(nodeId, recSig->buildSignHash().Get());
                retSigShares[nodeId].emplace_back(recSig);
            }
            ns.erase(ns.begin());
            return !ns.empty();
        }, rnd);

        if (retSigShares.empty()) {
            return false;
        }

        // Evaluated while still holding cs_pending
        more_work = std::any_of(pendingRecoveredSigs.begin(), pendingRecoveredSigs.end(),
                                [](const auto& p) { return !p.second.empty(); }) ||
                    !pendingReconstructedRecoveredSigs.empty();
    }

    // Quorum lookups happen after cs_pending has been released
    for (auto& [nodeId, v] : retSigShares) {
        for (auto it = v.begin(); it != v.end();) {
            const auto& recSig = *it;

            auto llmqType = recSig->getLlmqType();
            auto quorumKey = std::make_pair(recSig->getLlmqType(), recSig->getQuorumHash());
            // Each quorum is looked up and checked only once; later sigs reuse the cached public key
            if (!ret_pubkeys.count(quorumKey)) {
                auto quorum = qman.GetQuorum(llmqType, recSig->getQuorumHash());
                if (!quorum) {
                    LogPrint(BCLog::LLMQ, "CSigningManager::%s -- quorum %s not found, node=%d\n", __func__,
                              recSig->getQuorumHash().ToString(), nodeId);
                    it = v.erase(it);
                    continue;
                }
                if (!IsQuorumActive(llmqType, qman, quorum->qc->quorumHash)) {
                    LogPrint(BCLog::LLMQ, "CSigningManager::%s -- quorum %s not active anymore, node=%d\n", __func__,
                              recSig->getQuorumHash().ToString(), nodeId);
                    it = v.erase(it);
                    continue;
                }

                ret_pubkeys.emplace(quorumKey, quorum->qc->quorumPublicKey);
            }

            ++it;
        }
    }

    return more_work;
}

// Hands the caller all pending reconstructed recovered sigs and leaves the internal map empty.
// The exchange happens under cs_pending.
Uint256HashMap<std::shared_ptr<const CRecoveredSig>> CSigningManager::FetchPendingReconstructed()
{
    Uint256HashMap<std::shared_ptr<const CRecoveredSig>> fetched;
    WITH_LOCK(cs_pending, fetched.swap(pendingReconstructedRecoveredSigs));
    return fetched;
}

// Stores an already verified recovered sig (the signature must be verified by the caller).
// Returns true when the sig was written to the database, false when it was already known or
// when another recovered sig is already stored for the same id.
bool CSigningManager::ProcessRecoveredSig(const std::shared_ptr<const CRecoveredSig>& recoveredSig)
{
    const auto& recSigHash = recoveredSig->GetHash();
    if (db.HasRecoveredSigForHash(recSigHash)) {
        return false;
    }

    const auto type = recoveredSig->getLlmqType();
    const auto& sigId = recoveredSig->getId();
    auto signHash = recoveredSig->buildSignHash();

    LogPrint(BCLog::LLMQ, "CSigningManager::%s -- valid recSig. signHash=%s, id=%s, msgHash=%s\n", __func__,
            signHash.ToString(), sigId.ToString(), recoveredSig->getMsgHash().ToString());

    // If the lookup by id succeeds but the subsequent read fails, cleanup made this specific recSig
    // vanish in between (very unlikely). In that case treat it as if we never had that recSig.
    CRecoveredSig otherRecoveredSig;
    if (db.HasRecoveredSigForId(type, sigId) && db.GetRecoveredSigById(type, sigId, otherRecoveredSig)) {
        auto otherSignHash = otherRecoveredSig.buildSignHash();
        if (signHash.Get() != otherSignHash.Get()) {
            // this should really not happen, as each masternode is participating in only one vote,
            // even if it's a member of multiple quorums. so a majority is only possible on one quorum and one msgHash per id
            LogPrintf("CSigningManager::%s -- conflicting recoveredSig for signHash=%s, id=%s, msgHash=%s, otherSignHash=%s\n", __func__,
                      signHash.ToString(), sigId.ToString(), recoveredSig->getMsgHash().ToString(), otherSignHash.ToString());
        }
        // Otherwise we're trying to process a recSig that is already known. This might happen if the same
        // recSig comes in through regular QRECSIG messages and at the same time through some other message
        // which allowed to reconstruct a recSig (e.g. ISLOCK). Either way, bail out.
        return false;
    }

    db.WriteRecoveredSig(*recoveredSig);
    WITH_LOCK(cs_pending, pendingReconstructedRecoveredSigs.erase(recSigHash));

    return true;
}

// Returns a copy of the registered listeners, taken under cs_listeners.
std::vector<CRecoveredSigsListener*> CSigningManager::GetListeners() const
{
    LOCK(cs_listeners);
    std::vector<CRecoveredSigsListener*> snapshot = recoveredSigsListeners;
    return snapshot;
}

// Registers a reconstructed recovered sig as pending, keyed by its object hash.
// An entry that already exists for that hash is left untouched.
void CSigningManager::PushReconstructedRecoveredSig(const std::shared_ptr<const llmq::CRecoveredSig>& recoveredSig)
{
    LOCK(cs_pending);
    pendingReconstructedRecoveredSigs.emplace(recoveredSig->GetHash(), recoveredSig);
}

// Forwards to the database: truncates the recovered sig stored for (llmqType, id).
void CSigningManager::TruncateRecoveredSig(Consensus::LLMQType llmqType, const uint256& id)
{
    return db.TruncateRecoveredSig(llmqType, id);
}

// Expires old recovered sigs and old votes. The maximum age (in seconds) is taken from the
// -maxrecsigsage argument, falling back to DEFAULT_MAX_RECOVERED_SIGS_AGE.
void CSigningManager::Cleanup()
{
    const int64_t maxAge = gArgs.GetIntArg("-maxrecsigsage", DEFAULT_MAX_RECOVERED_SIGS_AGE);

    // The same age limit applies to both tables
    db.CleanupOldRecoveredSigs(maxAge);
    db.CleanupOldVotes(maxAge);
}

// Appends `l` to the list of recovered-sig listeners (no duplicate check is performed).
void CSigningManager::RegisterRecoveredSigsListener(CRecoveredSigsListener* l)
{
    LOCK(cs_listeners);
    recoveredSigsListeners.push_back(l);
}

// Removes every occurrence of `l` from the list of recovered-sig listeners.
void CSigningManager::UnregisterRecoveredSigsListener(CRecoveredSigsListener* l)
{
    LOCK(cs_listeners);
    // erase-remove idiom
    recoveredSigsListeners.erase(std::remove(recoveredSigsListeners.begin(), recoveredSigsListeners.end(), l),
                                 recoveredSigsListeners.end());
}

// Forwards to the database: is a recovered sig stored for exactly (llmqType, id, msgHash)?
bool CSigningManager::HasRecoveredSig(Consensus::LLMQType llmqType, const uint256& id, const uint256& msgHash) const
{
    const bool known = db.HasRecoveredSig(llmqType, id, msgHash);
    return known;
}

// Forwards to the database: is any recovered sig stored for (llmqType, id)?
bool CSigningManager::HasRecoveredSigForId(Consensus::LLMQType llmqType, const uint256& id) const
{
    const bool known = db.HasRecoveredSigForId(llmqType, id);
    return known;
}

// Forwards to the database: is a recovered sig stored for the session identified by signHash?
bool CSigningManager::HasRecoveredSigForSession(const uint256& signHash) const
{
    const bool known = db.HasRecoveredSigForSession(signHash);
    return known;
}

// Loads the recovered sig stored for (llmqType, id) into retRecSig.
// Returns false when the database has no readable entry for it.
bool CSigningManager::GetRecoveredSigForId(Consensus::LLMQType llmqType, const uint256& id, llmq::CRecoveredSig& retRecSig) const
{
    return db.GetRecoveredSigById(llmqType, id, retRecSig);
}

// A (llmqType, id, msgHash) triple conflicts when a recovered sig is present for the id but not
// for the given msgHash. No recovered sig for the id, or one for this very msgHash, is no conflict.
bool CSigningManager::IsConflicting(Consensus::LLMQType llmqType, const uint256& id, const uint256& msgHash) const
{
    // The msgHash lookup is only performed when a recovered sig exists for the id
    return db.HasRecoveredSigForId(llmqType, id) && !db.HasRecoveredSig(llmqType, id, msgHash);
}

// Forwards to the database: reads the msgHash we voted for on (llmqType, id) into msgHashRet.
bool CSigningManager::GetVoteForId(Consensus::LLMQType llmqType, const uint256& id, uint256& msgHashRet) const
{
    const bool found = db.GetVoteForId(llmqType, id, msgHashRet);
    return found;
}

// Builds the sign hash from this object's llmqType, quorumHash, id and msgHash.
SignHash CSigBase::buildSignHash() const
{
    return SignHash(llmqType, quorumHash, id, msgHash);
}


// Returns true when `quorumHash` is among the most recent `keepOldConnections` quorums of the
// given type, as reported by the quorum manager. The LLMQ type must be known to the chain params.
bool IsQuorumActive(Consensus::LLMQType llmqType, const CQuorumManager& qman, const uint256& quorumHash)
{
    const auto& llmq_params_opt = Params().GetLLMQ(llmqType);
    assert(llmq_params_opt.has_value());

    // sig shares and recovered sigs are only accepted from recent/active quorums
    // we allow one more active quorum as specified in consensus, as otherwise there is a small window where things could
    // fail while we are on the brink of a new quorum
    auto quorums = qman.ScanQuorums(llmqType, llmq_params_opt->keepOldConnections);
    for (const auto& q : quorums) {
        if (q->qc->quorumHash == quorumHash) {
            return true;
        }
    }
    return false;
}

} // namespace llmq
